{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IMDB Training Notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Knet.KnetArray{Float32,N} where N"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using Knet\n",
    "# Hyperparameters for the LSTM sentiment classifier\n",
    "EPOCHS      = 3        # passes over the training set\n",
    "BATCHSIZE   = 64       # examples per minibatch\n",
    "EMBEDSIZE   = 125      # embedding dimension (input size of the RNN)\n",
    "NUMHIDDEN   = 100      # hidden size of the RNN\n",
    "LR          = 0.0001   # Adam learning rate\n",
    "BETA_1      = 0.9      # Adam beta1\n",
    "BETA_2      = 0.999    # Adam beta2\n",
    "EPS         = 1e-08    # Adam eps\n",
    "MAXLEN      = 150      # maximum size of the word sequence\n",
    "MAXFEATURES = 30000    # vocabulary size\n",
    "DROPOUT     = 0.35     # dropout probability applied in predict when train=true\n",
    "SEED        = 1311194  # seed handed to the IMDB loader\n",
    "\n",
    "# Request GPU 0, then choose the array type: plain Arrays when gpu() reports\n",
    "# a negative device id, KnetArrays otherwise.\n",
    "gpu(0)\n",
    "atype = if gpu() < 0\n",
    "    Array{Float32}\n",
    "else\n",
    "    KnetArray{Float32}\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "initmodel (generic function with 1 method)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define the model\n",
    "# initmodel() creates an LSTM via rnninit (input size EMBEDSIZE, hidden size NUMHIDDEN)\n",
    "# plus two Xavier-initialised Float32 matrices converted with atype:\n",
    "#   - an (EMBEDSIZE, MAXFEATURES) embedding table, indexed by column in predict\n",
    "#   - a (2, NUMHIDDEN) output projection\n",
    "# Returns rnnSpec and the tuple (rnnWeights, inputMatrix, outputMatrix).\n",
    "function initmodel()\n",
    "    spec, lstmParams = rnninit(EMBEDSIZE, NUMHIDDEN; rnnType=:lstm)\n",
    "    embedding  = atype(xavier(Float32, EMBEDSIZE, MAXFEATURES))\n",
    "    projection = atype(xavier(Float32, 2, NUMHIDDEN))\n",
    "    return spec, (lstmParams, embedding, projection)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "savemodel (generic function with 1 method)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Store the trained parameters and the RNN spec in `localfile`\n",
    "# under the keys \"weights\" and \"rnnSpec\".\n",
    "# NOTE(review): `save` is not defined or imported in this notebook - presumably it is\n",
    "# provided by a JLD-style package loaded elsewhere (e.g. via imdb.jl); confirm on a fresh kernel.\n",
    "function savemodel(weights, rnnSpec; localfile=\"model_imdb.jld\")\n",
    "    return save(localfile,\n",
    "                \"weights\", weights,\n",
    "                \"rnnSpec\", rnnSpec)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass, loss, and gradient of the loss\n",
    "#\n",
    "# predict(weights, inputs, rnnSpec; train=false)\n",
    "#   weights : (rnnWeights, inputMatrix, outputMatrix) as returned by initmodel\n",
    "#   inputs  : collection of B index sequences, concatenated below into a (B,T) matrix\n",
    "#   train   : when true, dropout with probability DROPOUT is applied before and after the RNN\n",
    "# Returns the (2,B) class scores computed from the last time step of the RNN output.\n",
    "function predict(weights, inputs, rnnSpec; train=false)\n",
    "    lstmParams, embedding, projection = weights       # (1,1,W), (X,V), (2,H)\n",
    "    tokens = hcat(inputs...)'                          # (B,T)\n",
    "    embedded = embedding[:, tokens]                    # (X,B,T)\n",
    "    train && (embedded = dropout(embedded, DROPOUT))\n",
    "    hidden = rnnforw(rnnSpec, lstmParams, embedded)[1] # (H,B,T)\n",
    "    train && (hidden = dropout(hidden, DROPOUT))\n",
    "    lastStep = hidden[:, :, end]                       # (H,B)\n",
    "    return projection * lastStep                       # (2,H) * (H,B) = (2,B)\n",
    "end\n",
    "\n",
    "# Negative log-likelihood of the labels y under the scores from predict.\n",
    "function loss(w, x, y, r; train=false)\n",
    "    scores = predict(w, x, r; train=train)\n",
    "    return nll(scores, y)\n",
    "end\n",
    "\n",
    "# Gradient of loss with respect to its first argument (the weights).\n",
    "lossgradient = grad(loss);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mDownloading IMDB...\n",
      "\u001b[39m  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 16.6M  100 16.6M    0     0   758k      0  0:00:22  0:00:22 --:--:--  671k\n",
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 1602k  100 1602k    0     0   819k      0  0:00:01  0:00:01 --:--:--  819k\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 94.515819 seconds (27.42 M allocations: 1.331 GiB, 1.89% gc time)\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mLoading IMDB...\n",
      "\u001b[39m"
     ]
    }
   ],
   "source": [
    "# Load data\n",
    "# imdb() comes from the local imdb.jl; it returns train/test sequences and labels plus the\n",
    "# word => index dictionary. MAXLEN, MAXFEATURES and SEED are forwarded to the loader.\n",
    "include(\"imdb.jl\")\n",
    "@time (xtrn,ytrn,xtst,ytst,imdbdict)=imdb(maxlen=MAXLEN,maxval=MAXFEATURES,seed=SEED)\n",
    "for d in (xtrn,ytrn,xtst,ytst); println(summary(d)); end\n",
    "\n",
    "# Inverse vocabulary (index => word), used to print reviews as text.\n",
    "# Size it from the largest index actually present in the dictionary instead of a\n",
    "# hard-coded 88584, so every imdbarray[v] assignment below is guaranteed to be in bounds.\n",
    "imdbarray = Array{String}(Int(maximum(values(imdbdict))))\n",
    "for (k,v) in imdbdict; imdbarray[v]=k; end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample review:\n",
      "reilly reilly reilly reilly reilly reilly reilly 5hrs to start with i have to point out the fact that you're gonna feel completely lost for more than half an hour yeah some things happen but you don't know why or what for when you finally figure things out you just realize that it's nothing but a twisted spastic opera dealing with mature prostitutes dead mothers illegitimate sons the characters are rather poor and the actors specially the young ones don't help that much to spastic look credible only marisa spastic stands out but she's a superb actress no matter if the movie is pure rubbish br br the only positive things to say about spastic sol de spastic is that spastic pablo spastic seems to have good intentions and he's filmed a couple of scenes that are quite intense well maybe the next time br br my rate 4 10\n",
      "\n",
      "Classification: 1\n"
     ]
    }
   ],
   "source": [
    "# Show one randomly chosen training review, decoded back to words, with its label\n",
    "rnd = rand(1:length(xtrn))\n",
    "println(\"Sample review:\\n\", join(imdbarray[xtrn[rnd]], \" \"), \"\\n\")\n",
    "println(\"Classification: \", string(ytrn[rnd]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare for training\n",
    "# Drop any model left over from a previous run so its memory can be reclaimed\n",
    "weights = nothing\n",
    "knetgc()\n",
    "\n",
    "# Fresh parameters and Adam optimizer state built from the hyperparameters\n",
    "rnnSpec, weights = initmodel()\n",
    "optim = optimizers(weights, Adam; lr=LR, beta1=BETA_1, beta2=BETA_2, eps=EPS);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTraining...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 13.892884 seconds (2.06 M allocations: 136.613 MiB, 8.12% gc time)\n",
      "  3.815176 seconds (401.21 k allocations: 46.601 MiB, 27.58% gc time)\n",
      "  3.774615 seconds (401.94 k allocations: 46.613 MiB, 27.07% gc time)\n",
      " 21.487306 seconds (2.87 M allocations: 229.977 MiB, 14.90% gc time)\n"
     ]
    }
   ],
   "source": [
    "# Train for EPOCHS passes over shuffled minibatches.\n",
    "# The inner @time reports each epoch, the outer @time the whole run\n",
    "# (about 21 s in the recorded output).\n",
    "info(\"Training...\")\n",
    "@time for epoch in 1:EPOCHS\n",
    "    @time for (xbatch, ybatch) in minibatch(xtrn, ytrn, BATCHSIZE; shuffle=true)\n",
    "        # Gradient of the loss with dropout enabled, followed by one optimizer step\n",
    "        update!(weights, lossgradient(weights, xbatch, ybatch, rnnSpec; train=true), optim)\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTesting...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  3.328266 seconds (884.90 k allocations: 77.894 MiB, 2.24% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.8670272435897436"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate on the test set with dropout disabled (predict defaults to train=false).\n",
    "# partial=true keeps the final short minibatch: 25000 is not a multiple of BATCHSIZE=64,\n",
    "# so without it the last 40 test reviews are left out of the reported accuracy.\n",
    "info(\"Testing...\")\n",
    "@time accuracy(weights, minibatch(xtst,ytst,BATCHSIZE;partial=true), (w,x)->predict(w,x,rnnSpec))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained weights and RNN spec (file name spelled out; it equals savemodel's default)\n",
    "savemodel(weights, rnnSpec; localfile=\"model_imdb.jld\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.4",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
